"""Created on Wed Apr  1 15:58:30 2020.

Modified Causal Forest - Python implementation

Can be used under Creative Commons Licence CC BY-SA
Michael Lechner, SEW, University of St. Gallen, Switzerland

Version: 0.2.0 dev


"""
import sys

import pandas as pd
import numpy as np

from mcf import general_purpose as gp
from mcf import mcf_functions as mcf

# --- Which steps to run ------------------------------------------------------
TRAIN_MCF = True     # Train the forest (def: True)
PREDICT_MCF = True   # Estimate the effects (def: True)
SAVE_FOREST = False  # Save the forest so prediction can run without
#                      re-estimation (def: False)
# Base name of the files holding everything needed for prediction.
#   If a name is given, files with the extensions *.csv, *.pickle, *ps.npy
#   and *d.npy are derived from it automatically.
#   If None, the file names follow INDATA with the extensions
#   *_savepred.pickle, *_savepred.csv, *_savepredps.npy and *_savepredd.npy.
#   These 4 files are saved in the output directory (outpfad).
FOREST_FILES = None

# --- Paths -------------------------------------------------------------------
APPLIC_PATH = 'D:/MLechner/Re21Covid'  # Base directory; NOT passed to MCF
# Prefix of the output directories. (mcf: if the output path is None, an
# */out directory below the current directory is used; a directory that
# does not exist yet is created.)
OUTPATH_ALL = f'{APPLIC_PATH}/out'
# Data directory. (mcf: if a path is None, the path of this file is used.)
DATPATH = f'{APPLIC_PATH}/data'

# --- Data files (names without extension; '.csv' is appended automatically) --
INDATA = 'estimation_data_all_2022_07_12'    # csv file for estimation
PREDDATA = 'estimation_data_all_2022_07_12'  # csv file for prediction
#   If PREDDATA is not specified, INDATA is used as the file name.

#   Variables for estimation


# Time of treatment: December 2020

# 'Treatment_group_num' is not used as a treatment: massive common support
# problems!

# Treatment variables: ['Covid_dimension_num', 'KA_dimension_num',
#                       'Treatment_group_num']
# Covid_dimension_num: 1: low, 2: medium, 3: high
# KA_dimension_num: 1: low, 2: medium, 3: high
# Treatment_group_num:
# 1: Covid_low_KA_low, 2: Covid_low_KA_medium, 3: Covid_low_KA_high,
# 4: Covid_medium_KA_low, 5: Covid_medium_KA_medium, 6: Covid_medium_KA_high,
# 7: Covid_high_KA_low, 8: Covid_high_KA_medium, 9: Covid_high_KA_high

# Estimation grid: every (treatment, outcome, specification) combination is
# estimated in its own mcf run.
outcome_all = ['Q1_2021_rent_residentiallogp', 'Q1_2021_rent_officelogp',
               'Q1_2021_rent_retaillogp',
               'Q1_2021_sale_residentiallogp', 'Q1_2021_sale_officelogp',
               'Q1_2021_sale_retaillogp'
               ]
treatment_all = ['Covid_dimension_num', 'KA_dimension_num']
specification_all = ['fat', 'lean']

# Settings that are the same for every run. They live at module level
# because both the estimation and the final summary read them.
FS_YES = True    # True: feature selection active; False: not active (def)
WEIGHTED = True  # True: weighted estimation using W_NAME


def run_label(treat_name, out_name, spec_name):
    """Return the short label that identifies one estimation run.

    The label joins, with '_', the first 5 characters of the treatment
    name, the outcome name without its first 8 characters (the 'Q1_2021_'
    prefix of the outcomes used here) and the specification name,
    e.g. 'Covid_rent_officelogp_fat'.
    """
    return f'{treat_name[:5]}_{out_name[8:]}_{spec_name}'


def covariates_ord(spec_name):
    """Return a new list of ordered covariates for a specification.

    'fat' selects the full set of regional characteristics; any other value
    selects the 'lean' set of Q4 2019 log rents and sale prices. A new list
    is built on every call (as the original loop body did), so no list
    object is shared between mcf runs.
    """
    if spec_name == 'fat':
        return [
            'apt_share_owned_gov', 'apt_share_owned_others',
            'apt_share_owned_private_invest',
            'apt_share_owned_private_person', 'apt_total',
            'new_apt_total', 'new_non_res_fac_total',
            'new_res_fac_total', 'avg_apt_per_res_fac', 'avg_apt_size',
            'avg_people_per_household', 'empl_res_share_commute_out',
            'empl_wp_share_commute_in', 'empl_res_share_foreign',
            'empl_res_share_o55', 'empl_res_total', 'empl_wp_total',
            'ft_empl_wp_share_sector_primary',
            'ft_empl_wp_share_sector_secondary',
            'ft_empl_wp_share_sector_tertiary_G_H_L_O_P_Q_S_T_U',
            'ft_empl_wp_share_sector_tertiary_I_R',
            'ft_empl_wp_share_sector_tertiary_J_M_N',
            'ft_empl_wp_share_sector_tertiary_K',
            'hh_share_1_person', 'hh_share_dinks', 'hh_share_family',
            'hh_total', 'Q4_2019_rent_resh_share',
            'Q4_2019_rent_resarea', 'Q4_2019_rent_resc_year',
            'Q4_2019_rent_offarea', 'Q4_2019_rent_offc_year',
            'Q4_2019_rent_retarea', 'Q4_2019_rent_retc_year',
            'HO_occ_index', 'log_income_pp', 'pop_avg_age',
            'pop_share_age_0_9', 'pop_share_age_10_19',
            'pop_share_citizen_EU27', 'pop_share_citizen_germany',
            'pop_share_inflow', 'pop_share_outflow',
            'population_total', 'area', 'Q4_2019_rent_officelogp',
            'Q4_2019_rent_residentiallogp', 'Q4_2019_rent_retaillogp',
            'res_fac_share_13_more_apt',
            'res_fac_share_detached_house', 'res_fac_share_res_build',
            'res_fac_total', 'unempl_div_empl_res', 'unempl_res_total',
            'urbanization', 'vote_share_left', 'vote_share_mid',
            'vote_share_right',
            'priv_alle_tech_200'
            ]
    return ['Q4_2019_rent_officelogp',
            'Q4_2019_rent_residentiallogp',
            'Q4_2019_rent_retaillogp',
            'Q4_2019_sale_officelogp',
            'Q4_2019_sale_residentiallogp',
            'Q4_2019_sale_retaillogp']


def estimate_effects(treat_name, out_name, spec_name):
    """Estimate one modified causal forest and return its average effects.

    Parameters
    ----------
    treat_name : str
        Treatment variable (passed to mcf as ``d_name=[treat_name]``).
    out_name : str
        Outcome variable (passed to mcf as ``y_name``).
    spec_name : str
        'fat' for the full covariate set, anything else for the lean one.

    Returns
    -------
    ate_row, ate_se_row
        ``ate_i[0, 0, :]`` and ``ate_se_i[0, 0, :]`` of the mcf results
        (presumably first outcome, ATE rather than ATET, all treatment
        comparisons -- TODO confirm against the mcf documentation).
    label : str
        Run label, see ``run_label``.
    outpath : str
        Output directory used for this run.

    All settings below are created anew on every call, exactly as in the
    original loop body.
    """
    label = run_label(treat_name, out_name, spec_name)
    outpath = OUTPATH_ALL + '_' + label
    outfiletext = INDATA[:3] + label  # File for text output
    if WEIGHTED:
        outfiletext += '_WEIGHT'
        outpath += '_WEIGHT'

    Y_NAME = out_name
    D_NAME = [treat_name]
    Y_TREE_NAME = None
    X_NAME_ORD = covariates_ord(spec_name)
    X_NAME_UNORD = []
    # Identifier
    ID_NAME = []
    CLUSTER_NAME = []  # Variable defining the clusters if used
    # NOTE(review): the 'fat' covariates contain 'population_total' while
    # the weight variable is 'pop_total'; confirm both exist in the data.
    W_NAME = ['pop_total']

    X_NAME_ALWAYS_IN_ORD = []      # (not needed for pred)
    X_NAME_ALWAYS_IN_UNORD = []    # (not needed for pred)
    # Heterogeneity (GATE) variables; see the mcf documentation for how
    # they are coded as ordered or unordered.
    Z_NAME_LIST = [
        'apt_share_owned_private_invest',
        'avg_apt_per_res_fac',
        'avg_apt_size',
        'ft_empl_wp_share_sector_tertiary_I_R',
        'ft_empl_wp_share_sector_tertiary_K',
        'HO_occ_index',
        'log_income_pp',
        'Q4_2019_rent_officelogp',
        'Q4_2019_rent_residentiallogp',
        'Q4_2019_rent_retaillogp',
        'urbanization',
        'priv_alle_tech_200'
        ]
    Z_NAME_SPLIT_ORD = []
    Z_NAME_SPLIT_UNORD = []
    Z_NAME_MGATE = []
    Z_NAME_AMGATE = []
    # Variables to be excluded from preliminary feature selection
    X_NAME_REMAIN_ORD = []
    X_NAME_REMAIN_UNORD = []
    X_BALANCE_NAME_ORD = []
    X_BALANCE_NAME_UNORD = []

    OUTPUT_TYPE = None  # 0: output goes to terminal
    #                     1: output goes to file
    #                     2: output goes to file and terminal (def)
    CI_LEVEL = None

    # Multiprocessing
    MP_PARALLEL = None  # number of parallel processes (>0)
    #   default: logical cores - 2 (reduce if memory problems!)
    #   0, 1: no parallel computations
    BOOT = None
    # Data cleaning
    SCREEN_COVARIATES = True  # True (default): screen covariates (sc)
    CHECK_PERFECTCORR = True
    MIN_DUMMY_OBS = None      # if sc: minimum obs. of dummies in one
    #                           category (original comment truncated)
    CLEAN_DATA_FLAG = True
    # Estimation methods
    MCE_VART = None  # splitting rule
    #   0: mse's of regression only considered
    #   1: mse+mce criterion (def, None)
    #   2: -var(effect): heterogeneity maximising splitting rule of
    #      Wager & Athey (2018)
    #   3: randomly switching between outcome-mse+mce criterion and
    #      penalty functions
    # Penalty function
    P_DIFF_PENALTY = None  # depends on MCE_VART
    FS_OTHER_SAMPLE = True
    FS_OTHER_SAMPLE_SHARE = None
    FS_RF_THRESHOLD = None
    # Local centering
    L_CENTERING = True              # False: no local centering (def: True)
    L_CENTERING_NEW_SAMPLE = False  # (def: False)
    L_CENTERING_SHARE = None        # Share of data used for local centering
    L_CENTERING_CV_K = None
    # Common support
    SUPPORT_CHECK = None
    SUPPORT_QUANTIL = None
    SUPPORT_MIN_P = None
    SUPPORT_MAX_DEL_TRAIN = 0.5
    VARIABLE_IMPORTANCE_OOB = True
    BALANCING_TEST = None
    # Truncation of extreme weights
    MAX_WEIGHT_SHARE = None
    # Subsampling
    SUBSAMPLE_FACTOR_FOREST = None
    SUBSAMPLE_FACTOR_EVAL = False
    MATCH_NN_PROG_SCORE = True  # False: use Mahalanobis matching
    NN_MAIN_DIAG_ONLY = False   # Only if MATCH_NN_PROG_SCORE is False
    RANDOM_THRESHOLDS = None    # 0: no random thresholds
    #   > 0: number of random thresholds used for ordered variables
    # Minimum leaf size
    N_MIN_MIN = None       # smallest minimum leaf size (def: -1)
    N_MIN_MAX = None       # largest minimum leaf size (def: -1)
    N_MIN_GRID = None      # number of grid values (def: 1)
    ALPHA_REG_MIN = None
    ALPHA_REG_MAX = None   # 0 <= alpha < 0.5 (def: 0.1)
    ALPHA_REG_GRID = None  # number of grid values (def: 1)

    M_MIN_SHARE = None
    #   minimum share of variables used for next split (0-1); def = -1
    M_MAX_SHARE = None
    M_GRID = None  # m_try
    M_RANDOM_POISSON = True
    COND_VAR_FLAG = True  # False: variance estimation uses var(wy)
    KNN_FLAG = None       # False: Nadaraya-Watson estimation
    KNN_MIN_K = None      # k: minimum number of neighbours in
    #                       k-nn estimation (def: 10)
    KNN_CONST = None      # constant in number of neighbours
    NW_BANDW = None       # bandwidth for nw estimation; multiplier
    #                       of Silverman's optimal bandwidth (None: 1)
    NW_KERN_FLAG = None   # kernel for nw estimation:
    #                       1: Epanechnikov (def); 2: normal
    SE_BOOT_ATE = None
    SE_BOOT_GATE = None   # False: no bootstrap SE of effects
    SE_BOOT_IATE = None   # True: 199 bootstraps
    #   if CLUSTER_STD is False: default is False
    #   if CLUSTER_STD is True: default is 199; block-bootstrap is used
    MAX_CATS_Z_VARS = None  # maximum number of categories for z variables
    PANEL_DATA = False
    PANEL_IN_RF = False
    CLUSTER_STD = False
    CHOICE_BASED_SAMPLING = False
    CHOICE_BASED_WEIGHTS = [0.9, 0.8, 0.9, 0.95]
    ATET_FLAG = False
    GATET_FLAG = False
    IATE_FLAG = True
    IATE_SE_FLAG = True

    GMATE_NO_EVALUATION_POINTS = None  # Number of evaluation points
    GMATE_SAMPLE_SHARE = None
    SMOOTH_GATES = True
    SMOOTH_GATES_BANDWIDTH = None
    # Analysis of predicted values
    POST_EST_STATS = True
    RELATIVE_TO_FIRST_GROUP_ONLY = True
    BIN_CORR_YES = True
    BIN_CORR_THRESHOLD = None
    POST_PLOTS = True       # plots of estimated treatment effects
    #                         in pred_eff_data (def: True)
    POST_KMEANS_YES = True  # k-means clustering to analyse
    #                         patterns in the estimated effects (def: True)
    POST_KMEANS_NO_OF_GROUPS = None
    #   To be built: integer, list or tuple (or None --> default).
    #   Def: list of 5 values: [a, b, c, d, e]; c = 5 to 10, depending on n;
    #   c<7: a=c-2, b=c-1, d=c+1, e=c+2, else a=c-4, b=c-2, d=c+2, e=c+4
    POST_KMEANS_REPLICATIONS = None
    POST_KMEANS_MAX_TRIES = None
    POST_RANDOM_FOREST_VI = True
    # Sample splitting to reduce computational costs
    REDUCE_SPLIT_SAMPLE = False  # Default (None) is False
    REDUCE_SPLIT_SAMPLE_PRED_SHARE = None
    REDUCE_TRAINING = False
    REDUCE_TRAINING_SHARE = None
    REDUCE_PREDICTION = False
    REDUCE_PREDICTION_SHARE = None
    REDUCE_LARGEST_GROUP_TRAIN = False
    # -------------------------------------------------------------------------
    _SMALLER_SAMPLE = 0
    _MAX_CATS_CONT_VARS = None
    #   number of categories for continuous variables; n values < n speed up
    #   the programme; def: not used.
    _WITH_OUTPUT = True  # use print statements
    _MAX_SAVE_VALUES = 50
    _SEED_SAMPLE_SPLIT = 67567885
    _MP_RAY_DEL = None
    _MP_RAY_SHUTDOWN = None
    _MP_RAY_OBJSTORE_MULTIPLIER = None
    # -------------------------------------------------------------------------
    # NOTE(review): the original script also defined the settings below but
    # never passed them to modified_causal_forest, so they had no effect:
    #   VERBOSE = True, DESCRIPTIVE_STATS = True, SHOW_PLOTS = True,
    #   FONTSIZE = None, DPI = None, NO_FILLED_PLOT = None,
    #   MP_WITH_RAY = True, MP_VIM_TYPE = None, MP_WEIGHTS_TYPE = None,
    #   MP_WEIGHTS_TREE_BATCH = None, WEIGHT_AS_SPARSE = True,
    #   STOP_EMPTY = None, SHARE_FOREST_SAMPLE = None,
    #   SMOOTH_GATES_NO_EVALUATION_POINT = None,
    #   REDUCE_LARGEST_GROUP_TRAIN_SHARE = None.
    # Check the keyword names of the installed mcf version before wiring
    # any of them in.

    ate_i, ate_se_i, *_ = mcf.modified_causal_forest(
        outpfad=outpath, datpfad=DATPATH, indata=INDATA,
        preddata=PREDDATA,
        outfiletext=outfiletext, output_type=OUTPUT_TYPE,
        save_forest=SAVE_FOREST, forest_files=FOREST_FILES,
        ci_level=CI_LEVEL,
        clean_data_flag=CLEAN_DATA_FLAG,
        screen_covariates=SCREEN_COVARIATES,
        min_dummy_obs=MIN_DUMMY_OBS,
        check_perfectcorr=CHECK_PERFECTCORR,
        panel_data=PANEL_DATA, panel_in_rf=PANEL_IN_RF,
        weighted=WEIGHTED,
        cluster_std=CLUSTER_STD,
        choice_based_sampling=CHOICE_BASED_SAMPLING,
        choice_based_weights=CHOICE_BASED_WEIGHTS,
        match_nn_prog_score=MATCH_NN_PROG_SCORE,
        nn_main_diag_only=NN_MAIN_DIAG_ONLY,
        n_min_grid=N_MIN_GRID, n_min_min=N_MIN_MIN,
        n_min_max=N_MIN_MAX,
        m_min_share=M_MIN_SHARE, m_max_share=M_MAX_SHARE,
        m_grid=M_GRID,
        m_random_poisson=M_RANDOM_POISSON,
        alpha_reg_min=ALPHA_REG_MIN,
        alpha_reg_max=ALPHA_REG_MAX,
        alpha_reg_grid=ALPHA_REG_GRID,
        mce_vart=MCE_VART, p_diff_penalty=P_DIFF_PENALTY,
        boot=BOOT,
        knn_flag=KNN_FLAG, knn_const=KNN_CONST,
        nw_kern_flag=NW_KERN_FLAG,
        knn_min_k=KNN_MIN_K, cond_var_flag=COND_VAR_FLAG,
        nw_bandw=NW_BANDW,
        subsample_factor_forest=SUBSAMPLE_FACTOR_FOREST,
        subsample_factor_eval=SUBSAMPLE_FACTOR_EVAL,
        atet_flag=ATET_FLAG, gatet_flag=GATET_FLAG,
        iate_flag=IATE_FLAG, iate_se_flag=IATE_SE_FLAG,
        max_cats_z_vars=MAX_CATS_Z_VARS,
        gmate_no_evaluation_points=GMATE_NO_EVALUATION_POINTS,
        gmate_sample_share=GMATE_SAMPLE_SHARE,
        smooth_gates=SMOOTH_GATES,
        smooth_gates_bandwidth=SMOOTH_GATES_BANDWIDTH,
        l_centering=L_CENTERING,
        l_centering_share=L_CENTERING_SHARE,
        l_centering_new_sample=L_CENTERING_NEW_SAMPLE,
        l_centering_cv_k=L_CENTERING_CV_K, fs_yes=FS_YES,
        fs_other_sample=FS_OTHER_SAMPLE,
        fs_rf_threshold=FS_RF_THRESHOLD,
        fs_other_sample_share=FS_OTHER_SAMPLE_SHARE,
        support_min_p=SUPPORT_MIN_P, support_check=SUPPORT_CHECK,
        support_max_del_train=SUPPORT_MAX_DEL_TRAIN,
        support_quantil=SUPPORT_QUANTIL,
        max_weight_share=MAX_WEIGHT_SHARE,
        variable_importance_oob=VARIABLE_IMPORTANCE_OOB,
        balancing_test=BALANCING_TEST,
        post_kmeans_max_tries=POST_KMEANS_MAX_TRIES,
        post_random_forest_vi=POST_RANDOM_FOREST_VI,
        bin_corr_yes=BIN_CORR_YES, post_plots=POST_PLOTS,
        post_est_stats=POST_EST_STATS,
        post_kmeans_yes=POST_KMEANS_YES,
        relative_to_first_group_only=RELATIVE_TO_FIRST_GROUP_ONLY,
        bin_corr_threshold=BIN_CORR_THRESHOLD,
        post_kmeans_no_of_groups=POST_KMEANS_NO_OF_GROUPS,
        post_kmeans_replications=POST_KMEANS_REPLICATIONS,
        id_name=ID_NAME, cluster_name=CLUSTER_NAME, w_name=W_NAME,
        d_name=D_NAME, y_tree_name=Y_TREE_NAME, y_name=Y_NAME,
        x_name_ord=X_NAME_ORD, x_name_unord=X_NAME_UNORD,
        x_name_always_in_ord=X_NAME_ALWAYS_IN_ORD,
        x_name_always_in_unord=X_NAME_ALWAYS_IN_UNORD,
        x_name_remain_ord=X_NAME_REMAIN_ORD,
        x_name_remain_unord=X_NAME_REMAIN_UNORD,
        x_balance_name_ord=X_BALANCE_NAME_ORD,
        x_balance_name_unord=X_BALANCE_NAME_UNORD,
        z_name_list=Z_NAME_LIST, z_name_split_ord=Z_NAME_SPLIT_ORD,
        z_name_split_unord=Z_NAME_SPLIT_UNORD,
        z_name_mgate=Z_NAME_MGATE,
        z_name_amgate=Z_NAME_AMGATE,
        random_thresholds=RANDOM_THRESHOLDS,
        mp_parallel=MP_PARALLEL,
        predict_mcf=PREDICT_MCF, train_mcf=TRAIN_MCF,
        se_boot_ate=SE_BOOT_ATE, se_boot_gate=SE_BOOT_GATE,
        se_boot_iate=SE_BOOT_IATE,
        reduce_split_sample=REDUCE_SPLIT_SAMPLE,
        reduce_split_sample_pred_share=REDUCE_SPLIT_SAMPLE_PRED_SHARE,
        reduce_training=REDUCE_TRAINING,
        reduce_training_share=REDUCE_TRAINING_SHARE,
        reduce_prediction=REDUCE_PREDICTION,
        reduce_prediction_share=REDUCE_PREDICTION_SHARE,
        reduce_largest_group_train=REDUCE_LARGEST_GROUP_TRAIN,
        _mp_ray_objstore_multiplier=_MP_RAY_OBJSTORE_MULTIPLIER,
        _with_output=_WITH_OUTPUT,
        _max_save_values=_MAX_SAVE_VALUES,
        _smaller_sample=_SMALLER_SAMPLE,
        _seed_sample_split=_SEED_SAMPLE_SPLIT,
        _max_cats_cont_vars=_MAX_CATS_CONT_VARS,
        _mp_ray_del=_MP_RAY_DEL,
        _mp_ray_shutdown=_MP_RAY_SHUTDOWN)
    return ate_i[0, 0, :], ate_se_i[0, 0, :], label, outpath


def effect_table_lines(ate, ate_se, ate_name):
    """Return the lines of the compact effect table used in the summary.

    The first line is a header whose column labels (2-1, 3-1, 3-2) assume
    three treatment levels, as both treatments here have. Each further line
    holds a run label padded to 25 characters plus ':' and then, for every
    comparison stored for that run, the effect and its standard error
    (each formatted '8.3f'). A run with fewer comparisons therefore no
    longer raises IndexError.
    """
    lines = ['Outcome name                 2-1 effect   SE  3-1 effect   SE'
             '  3-2 effect   SE']
    for name, effects, std_errs in zip(ate_name, ate, ate_se):
        cells = ''.join(f'{eff:8.3f}  {se:8.3f} '
                        for eff, se in zip(effects, std_errs))
        lines.append(f'{name:25}:{cells}')
    return lines


def write_summary(ate, ate_se, ate_name, outpath):
    """Print a quick summary of all runs and save the effects as CSV files.

    Printed output is routed through ``gp.OutputTerminalFile`` into
    '<outpath>/summary.txt' (presumably also echoed to the terminal). The
    effects and standard errors are saved to '<outpath>/ATE.csv' and
    '<outpath>/ATE_SE.csv', one column per run. stdout is restored and the
    summary file closed even if printing or writing fails.
    """
    outfiletext = outpath + '/summary.txt'
    orig_stdout = sys.stdout
    gp.delete_file_if_exists(outfiletext)
    tee = gp.OutputTerminalFile(outfiletext)
    sys.stdout = tee
    try:
        print()
        print('=' * 80)
        print('Quick Summary of results')
        print('Weighted' if WEIGHTED else 'Unweighted')
        print('Feature selection' if FS_YES else 'No feature selection')

        print('-' * 80)
        for name, ate_row, se_row in zip(ate_name, ate, ate_se):
            print(name, 'ATE:', *ate_row)
            print(name, 'ATE_SE:', *se_row)
            print('- ' * 40)
        for line in effect_table_lines(ate, ate_se, ate_name):
            print(line)

        # One row per treatment comparison, one column per run.
        ate_pd = pd.DataFrame(data=np.array(ate).T, columns=ate_name)
        ate_se_pd = pd.DataFrame(data=np.array(ate_se).T, columns=ate_name)
        ate_pd.to_csv(outpath + '/ATE.csv', index=False)
        ate_se_pd.to_csv(outpath + '/ATE_SE.csv', index=False)
    finally:
        sys.stdout = orig_stdout
        tee.output.close()


def main():
    """Estimate every run of the grid, then write the combined summary.

    Returns the lists (ate, ate_se, ate_name) so that they stay available
    when the script is run interactively.
    """
    ate, ate_se, ate_name = [], [], []
    outpath = None
    for treat_name in treatment_all:
        for out_name in outcome_all:
            for spec_name in specification_all:
                ate_row, se_row, label, outpath = estimate_effects(
                    treat_name, out_name, spec_name)
                ate.append(ate_row)
                ate_se.append(se_row)
                ate_name.append(label)
    # As before, the combined summary goes to the output directory of the
    # last run (a directory mcf has used, so it is expected to exist).
    if outpath is not None:
        write_summary(ate, ate_se, ate_name, outpath)
    return ate, ate_se, ate_name


if __name__ == '__main__':
    # The whole workflow is guarded: importing this module (e.g. a worker
    # process started with the 'spawn' method re-imports the main module)
    # must neither re-run the estimations nor overwrite the summary files.
    ate, ate_se, ate_name = main()
